import json
import os

import mindspore
from cybertron import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
from mindspore import context
import numpy as np

# Run MindSpore in graph mode on the CPU (inference-only module, no GPU assumed).
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# Directory containing this module; used to resolve data files relative to it.
module_dir = os.path.dirname(__file__)


def get_match_question(input_encode):
    """Rank stored question encodings by cosine similarity to the input.

    Args:
        input_encode: encoder output as returned by ``encode_sentence``;
            the pooled sentence vector is taken from ``input_encode[1][0]``
            (assumed to be a MindSpore tensor exposing ``.asnumpy()`` —
            TODO confirm against the model's output structure).

    Returns:
        ``[best_key]`` when the best match scores >= 0.7; the top two keys
        when the best match is below 0.7 and more than one candidate exists;
        ``None`` when there are no candidates, or when the single candidate
        scores below the threshold (preserved from the original logic).
    """
    encode_path = os.path.join(module_dir, '../data/resource_sentence_encode.json')
    # Explicit encoding so the JSON loads identically on every platform.
    with open(encode_path, encoding='utf-8') as f:
        resource_sentences_encode = json.load(f)

    input_vector = input_encode[1][0].asnumpy()
    question_similarity = {}
    for question, encoding in resource_sentences_encode.items():
        # cosine_similarity returns a 1x1 matrix; extract the scalar so the
        # sort and threshold comparisons below operate on plain floats
        # instead of relying on size-1 ndarray truthiness.
        score = cosine_similarity([input_vector], [np.asarray(encoding)])[0][0]
        question_similarity[question] = score

    ranked = sorted(question_similarity.items(), key=lambda item: item[1], reverse=True)
    if not ranked:
        return None
    if ranked[0][1] >= 0.7:
        # Confident match: return only the best candidate.
        return [ranked[0][0]]
    if len(ranked) > 1:
        # Low confidence: offer the two closest candidates.
        return [ranked[0][0], ranked[1][0]]
    # Exactly one candidate and it is below the confidence threshold.
    return None


# Cache for the (tokenizer, model) pair so the expensive disk load of the
# pretrained weights happens only once per process instead of on every call.
_ENCODER_CACHE = {}


def _load_encoder():
    """Load and cache the MiniLM tokenizer and model in eval mode."""
    if not _ENCODER_CACHE:
        tokenizer = BertTokenizer.load('sentence-transformers/all-MiniLM-L6-v2')
        model = BertModel.load('sentence-transformers/all-MiniLM-L6-v2')
        model.set_train(False)  # inference only — disable training behavior
        _ENCODER_CACHE['tokenizer'] = tokenizer
        _ENCODER_CACHE['model'] = model
    return _ENCODER_CACHE['tokenizer'], _ENCODER_CACHE['model']


def encode_sentence(input_sentence):
    """Encode a sentence with the pretrained MiniLM BERT model.

    Args:
        input_sentence: the raw sentence string to encode.

    Returns:
        The model's output for the tokenized sentence (structure defined by
        ``BertModel``; ``get_match_question`` reads the vector at ``[1][0]``).
    """
    tokenizer, model = _load_encoder()
    input_token = mindspore.Tensor(
        [tokenizer.encode(input_sentence, add_special_tokens=True)], mindspore.int32)
    return model(input_token)


def load_q_a_data(file_path):
    """Load the question/answer data set from a JSON file.

    Args:
        file_path: path to the JSON file holding the Q&A data.

    Returns:
        The deserialized JSON content.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding: the platform default is locale-dependent and can
    # break on non-ASCII question text.
    with open(file_path, encoding='utf-8') as f:
        return json.load(f)
